package it.uniroma2.tk;

import it.uniroma2.util.tree.Tree;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;

/**
 * Implements the Partial Tree Kernel (PTK) computation between {@link Tree} objects.
 *
 * <p>{@link #value(Tree, Tree)} returns the sum of {@code delta(n1, n2)} over every pair of
 * nodes {@code n1} of the first tree and {@code n2} of the second tree. {@code delta} is zero
 * for nodes with different root labels; otherwise it scores the ordered child subsequences
 * the two nodes share. The weights are read from the public static fields {@link #lambda}
 * and {@link #mu}.
 *
 * <p>Intermediate results are memoized in static tables that are (re)created at the start of
 * each {@link #value(Tree, Tree)} call and released at its end. Because this state is static
 * and shared, the class is not thread-safe: concurrent calls to {@code value} (or changing
 * {@code lambda}/{@code mu} during a call) interfere with each other.
 *
 * @author Lorenzo Dell'Arciprete
 */
public class PartialTreeKernel {

	/**
	 * Gap decay: each skipped child inside a matched child subsequence multiplies its
	 * contribution by {@code lambda}; every matching node pair also adds a {@code lambda^2} term.
	 */
	public static double lambda = 1;
	/** Multiplier applied to every node-pair score {@code delta(n1, n2)}. */
	public static double mu = 1;
	/** Next free node id; reset at the start of each {@link #value(Tree, Tree)} call. */
	private static int nodeCount = 0;
	/** Memo for {@code delta(n1, n2)}, keyed by "id(n1):id(n2)". */
	private static Map<String,Double> deltaMatrix;
	/** Memo for {@code deltaP(p, c1, c2)}, keyed by {@link #memoKey}. */
	private static Map<String,Double> deltaPMatrix;
	/** Memo for {@code DP(p, c1, c2)}, keyed by {@link #memoKey}. */
	private static Map<String,Double> dPMatrix;
	/**
	 * Dense integer id per node, used to build memo keys. Keyed by object identity so that
	 * the ids never depend on how {@code Tree} implements {@code equals}/{@code hashCode}.
	 */
	private static Map<Tree,Integer> nodeIndices;

	/**
	 * Computes the Partial Tree Kernel between two trees.
	 *
	 * @param a the first tree
	 * @param b the second tree
	 * @return the sum of {@code delta(n1, n2)} over all node pairs, with {@code n1} in
	 *         {@code a} and {@code n2} in {@code b}
	 */
	public static double value(Tree a, Tree b) {
		deltaMatrix = new HashMap<String,Double>();
		deltaPMatrix = new HashMap<String,Double>();
		dPMatrix = new HashMap<String,Double>();
		nodeIndices = new IdentityHashMap<Tree,Integer>();
		nodeCount = 0;
		try {
			// Build each node list once (a first, then b, so ids are assigned in that order).
			List<Tree> nodesA = allNodes(a);
			List<Tree> nodesB = allNodes(b);
			double sum = 0;
			for (Tree aa : nodesA)
				for (Tree bb : nodesB)
					sum += delta(aa, bb);
			return sum;
		} finally {
			// Release the memo tables so the trees are not kept reachable after the call.
			deltaMatrix = null;
			deltaPMatrix = null;
			dPMatrix = null;
			nodeIndices = null;
		}
	}

	/**
	 * Scores the node pair (a, b). Returns 0 when the root labels differ; otherwise
	 * {@code mu * (lambda^2 + sum_{p=1..lm} deltaP(p, children(a), children(b)))}, where
	 * {@code lm} is the smaller of the two child counts. Results are memoized per node pair.
	 */
	private static double delta(Tree a, Tree b) {
		if (!a.getRootLabel().equals(b.getRootLabel()))
			return 0;
		String key = nodeIndices.get(a) + ":" + nodeIndices.get(b);
		Double cached = deltaMatrix.get(key);
		if (cached != null)
			return cached;

		List<Tree> childrenA = a.getChildren();
		List<Tree> childrenB = b.getChildren();
		int lm = Math.min(childrenA.size(), childrenB.size());
		double k = lambda * lambda;
		for (int p = 1; p <= lm; p++)
			k += deltaP(p, childrenA, childrenB);
		k = mu * k;
		deltaMatrix.put(key, k);
		return k;
	}

	/**
	 * Sums, over all pairs of length-{@code p} subsequences J1 of {@code c1} and J2 of
	 * {@code c2}, the product of {@code delta} over the aligned nodes, weighted by
	 * {@code lambda^((last(J1)-first(J1)) + (last(J2)-first(J2)))}.
	 *
	 * <p>Recursion on the last node of {@code c1}: either it is not used, or it is matched with
	 * some {@code c2[j]}, which then closes J2 and leaves {@code p-1} pairs for the prefixes.
	 */
	private static double deltaP(int p, List<Tree> c1, List<Tree> c2) {
		if (Math.min(c1.size(), c2.size()) < p)
			return 0;
		String key = memoKey(p, c1, c2);
		Double cached = deltaPMatrix.get(key);
		if (cached != null)
			return cached;

		int lastPos = c1.size() - 1;
		List<Tree> head1 = c1.subList(0, lastPos);
		Tree last = c1.get(lastPos);
		// Case 1: the last node of c1 is not part of J1.
		double res = deltaP(p, head1, c2);
		// Case 2: the last node of c1 is matched with c2[j]. The position j is tracked
		// explicitly: indexOf would be O(n) and would pick the wrong position if a node
		// occurs twice, or if two siblings compare equal under Tree.equals.
		int j = 0;
		for (Tree n : c2) {
			if (n.getRootLabel().equals(last.getRootLabel()))
				res += delta(last, n) * DP(p - 1, head1, c2.subList(0, j));
			j++;
		}
		deltaPMatrix.put(key, res);
		return res;
	}

	/**
	 * Helper for {@link #deltaP}: same subsequence sum, but each pair (J1, J2) is weighted by
	 * {@code lambda^((|c1| - first(J1)) + (|c2| - first(J2)))} (0-based positions), i.e. by
	 * the distance from the start of the subsequences to the end of the prefixes. That
	 * distance is how far they lie from the node matched next by the caller.
	 * {@code DP(0, ...)} is 1.
	 */
	private static double DP(int p, List<Tree> c1, List<Tree> c2) {
		if (p == 0)
			return 1;
		if (Math.min(c1.size(), c2.size()) < p)
			return 0;
		String key = memoKey(p, c1, c2);
		Double cached = dPMatrix.get(key);
		if (cached != null)
			return cached;

		int lastPos = c1.size() - 1;
		List<Tree> head1 = c1.subList(0, lastPos);
		Tree last = c1.get(lastPos);
		// Skipping the last node of c1 stretches the distance by one: one more lambda factor.
		double res = lambda * DP(p, head1, c2);
		int j = 0;
		for (Tree n : c2) {
			// Matching c2[j] leaves c2.size() - j trailing positions in c2, plus one for c1's
			// last node.
			if (n.getRootLabel().equals(last.getRootLabel()))
				res += Math.pow(lambda, (c2.size() - j + 1)) *
					delta(last, n) * DP(p - 1, head1, c2.subList(0, j));
			j++;
		}
		dPMatrix.put(key, res);
		return res;
	}

	/**
	 * Builds the memo key for (p, c1, c2): {@code p}, then each node id of {@code c1}
	 * prefixed by ':', then each node id of {@code c2} prefixed by ';'. The distinct separators
	 * keep the boundary between the two lists unambiguous.
	 */
	private static String memoKey(int p, List<Tree> c1, List<Tree> c2) {
		StringBuilder sb = new StringBuilder();
		sb.append(p);
		for (Tree t : c1)
			sb.append(':').append(nodeIndices.get(t));
		for (Tree t : c2)
			sb.append(';').append(nodeIndices.get(t));
		return sb.toString();
	}

	/**
	 * Returns all nodes of the tree rooted at {@code node} in pre-order. As a side effect,
	 * assigns the next free id to every node that does not have one yet.
	 */
	private static List<Tree> allNodes(Tree node) {
		List<Tree> all = new ArrayList<Tree>();
		collectNodes(node, all);
		return all;
	}

	/** Pre-order visit that assigns missing ids and appends each visited node to {@code out}. */
	private static void collectNodes(Tree node, List<Tree> out) {
		if (!nodeIndices.containsKey(node)) {
			nodeIndices.put(node, nodeCount);
			nodeCount++;
		}
		out.add(node);
		for (Tree child : node.getChildren())
			collectNodes(child, out);
	}

}
